#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Example code for TensorFlow Wide & Deep Tutorial using tf.estimator API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import time
import tensorflow as tf  # pylint: disable=g-bad-import-order
from official.utils.arg_parsers import parsers
from official.utils.logs import hooks_helper

_CSV_COLUMNS = [
    'age', 'workclass', 'fnlwgt', 'education', 'education_num',
    'marital_status', 'occupation', 'relationship', 'race', 'gender',
    'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
    'income_bracket'
]

_CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''],
                        [''], [''], [0], [0], [0], [''], ['']]

_NUM_EXAMPLES = {
    'train': 32561,
    'validation': 16281,
}
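
# Record counts of the UCI Census Income ("adult") dataset splits; the train
# count is also used to size the shuffle buffer in input_fn below.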


LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
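# The canned estimators build their graphs under a 'linear/' or 'dnn/' name
# scope; these prefixes let the logging hooks in main() locate the loss
# tensors by name (the combined model uses no prefix).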


def build_model_columns():
    """Builds a set of wide and deep feature columns."""
    # Continuous columns
    age = tf.feature_column.numeric_column('age')
    education_num = tf.feature_column.numeric_column('education_num')
    capital_gain = tf.feature_column.numeric_column('capital_gain')
    capital_loss = tf.feature_column.numeric_column('capital_loss')
    hours_per_week = tf.feature_column.numeric_column('hours_per_week')

    education = tf.feature_column.categorical_column_with_vocabulary_list(
        'education', [
            'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
            'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
            '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])

    marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
        'marital_status', [
            'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
            'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])

    relationship = tf.feature_column.categorical_column_with_vocabulary_list(
        'relationship', [
            'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
            'Other-relative'])

    workclass = tf.feature_column.categorical_column_with_vocabulary_list(
        'workclass', [
            'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
            'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])

    # To show an example of hashing:
    occupation = tf.feature_column.categorical_column_with_hash_bucket(
        'occupation', hash_bucket_size=1000)

    # Transformations.
    age_buckets = tf.feature_column.bucketized_column(
        age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
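    # With 10 boundaries, bucketized_column produces 11 one-hot buckets:
    # (-inf, 18), [18, 25), ..., [65, +inf).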

    # Wide columns and deep columns.
    base_columns = [
        education, marital_status, relationship, workclass, occupation,
        age_buckets,
    ]

    crossed_columns = [
        tf.feature_column.crossed_column(
            ['education', 'occupation'], hash_bucket_size=1000),
        tf.feature_column.crossed_column(
            [age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
    ]
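    # Crossed columns hash each feature combination into `hash_bucket_size`
    # buckets, so distinct crosses can collide; 1000 buckets trades model size
    # against collision rate.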

    wide_columns = base_columns + crossed_columns

    deep_columns = [
        age,
        education_num,
        capital_gain,
        capital_loss,
        hours_per_week,
        tf.feature_column.indicator_column(workclass),
        tf.feature_column.indicator_column(education),
        tf.feature_column.indicator_column(marital_status),
        tf.feature_column.indicator_column(relationship),
        # To show an example of embedding
        tf.feature_column.embedding_column(occupation, dimension=8),
    ]

    return wide_columns, deep_columns


def build_estimator(model_dir, model_type):
    """Build an estimator appropriate for the given model type."""
    wide_columns, deep_columns = build_model_columns()
    hidden_units = [100, 75, 50, 25]

    # Create a tf.estimator.RunConfig to ensure the model runs on CPU, which
    # is faster than GPU for this model.
    run_config = tf.estimator.RunConfig().replace(
        session_config=tf.compat.v1.ConfigProto(device_count={'GPU': 0}))
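    # device_count={'GPU': 0} hides all GPUs from the session, so ops are
    # placed on CPU regardless of the hardware that is actually visible.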

    if model_type == 'wide':
        return tf.estimator.LinearClassifier(
            model_dir=model_dir,
            feature_columns=wide_columns,
            config=run_config, loss_reduction=tf.compat.v1.losses.Reduction.SUM)
    elif model_type == 'deep':
        return tf.estimator.DNNClassifier(
            model_dir=model_dir,
            feature_columns=deep_columns,
            hidden_units=hidden_units,
            config=run_config, loss_reduction=tf.compat.v1.losses.Reduction.SUM)
    else:
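        # Unlike the branches above, the combined model averages its loss
        # (SUM_OVER_BATCH_SIZE) rather than summing it, so reported loss
        # values are not directly comparable across model types.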
        return tf.estimator.DNNLinearCombinedClassifier(
            model_dir=model_dir,
            linear_feature_columns=wide_columns,
            dnn_feature_columns=deep_columns,
            dnn_hidden_units=hidden_units,
            config=run_config,
            loss_reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)


def input_fn(data_file, num_epochs, shuffle, batch_size):
    """Generate an input function for the Estimator."""
    assert tf.io.gfile.exists(data_file), (
        '%s not found. Please make sure you have run data_download.py and '
        'set the --data_dir argument to the correct path.' % data_file)

    def parse_csv(value):
        print('Parsing', data_file)
        columns = tf.io.decode_csv(records=value, record_defaults=_CSV_COLUMN_DEFAULTS)
        features = dict(zip(_CSV_COLUMNS, columns))
        labels = features.pop('income_bracket')
        return features, tf.equal(labels, '>50K')

    # Extract lines from input files using the Dataset API.
    dataset = tf.data.TextLineDataset(data_file)

    if shuffle:
        dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])

    dataset = dataset.map(parse_csv, num_parallel_calls=5)

    # We call repeat after shuffling, rather than before, to prevent separate
    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    return dataset
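
# A minimal sanity check of this pipeline under TF2 eager execution, assuming
# the census data has already been downloaded to /tmp/census_data:
#
#   ds = input_fn('/tmp/census_data/adult.data', 1, False, 2)
#   features, labels = next(iter(ds))
#
# `features` maps each column name to a tensor of shape [2] and `labels` is a
# boolean tensor of shape [2].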


def main(argv):
    parser = WideDeepArgParser()
    flags = parser.parse_args(args=argv[1:])

    # Cleaning up the model directory is disabled here: evaluation needs the
    # previously trained checkpoint in flags.model_dir.
    # shutil.rmtree(flags.model_dir, ignore_errors=True)
    model = build_estimator(flags.model_dir, flags.model_type)

    train_file = os.path.join(flags.data_dir, 'adult.data')
    test_file = os.path.join(flags.data_dir, 'adult.test')

    # Count evaluation records up front so latency/throughput can be reported.
    with open(test_file) as f:
        num_records = sum(1 for _ in f)

    # Input functions for training and evaluation.
    def train_input_fn():
        return input_fn(
            train_file, flags.epochs_between_evals, True, flags.batch_size)

    def eval_input_fn():
        return input_fn(test_file, 1, False, flags.batch_size)

    loss_prefix = LOSS_PREFIX.get(flags.model_type, '')
    train_hooks = hooks_helper.get_train_hooks(
        flags.hooks, batch_size=flags.batch_size,
        tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
                        'loss': loss_prefix + 'head/weighted_loss/Sum'})
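    # NOTE: train_input_fn and train_hooks are kept from the original training
    # script but are not used below; this benchmark only runs evaluation.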

    inference_start = time.time()
    # Evaluate the model (inference only; no training step is run here).
    results = model.evaluate(input_fn=eval_input_fn)

    # Display evaluation metrics
    print('-' * 60)

    for key in sorted(results):
        print('%s: %s' % (key, results[key]))
    main_end = time.time()
    # main_start is set in the __main__ block below, so the end-to-end time
    # includes import, flag parsing, and estimator construction.
    e2e_duration = main_end - main_start
    print('End-to-End duration is %s' % e2e_duration)
    evaluate_duration = main_end - inference_start
    print('Evaluation duration is %s' % evaluate_duration)

    if flags.batch_size == 1:
        print('Latency is: %s' % (e2e_duration / num_records))
    else:
        print('Throughput is: %s' % (num_records / evaluate_duration))


class WideDeepArgParser(argparse.ArgumentParser):
    """Argument parser for running the wide deep model."""

    def __init__(self):
        super(WideDeepArgParser, self).__init__(parents=[parsers.BaseParser()])
        self.add_argument(
            '--model_type', '-mt', type=str, default='wide_deep',
            choices=['wide', 'deep', 'wide_deep'],
            help='[default %(default)s] Valid model types:'
                 ' wide, deep, wide_deep.',
            metavar='<MT>')
        self.set_defaults(
            data_dir='/tmp/census_data',
            model_dir='/tmp/census_model',
            train_epochs=40,
            epochs_between_evals=2,
            batch_size=40)


if __name__ == '__main__':
    main_start = time.time()
    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
    main(argv=sys.argv)
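
# Example invocation (the script name is illustrative; flags come from
# parsers.BaseParser plus --model_type above):
#
#   python census_inference.py --data_dir /tmp/census_data \
#       --model_dir /tmp/census_model --model_type wide_deep --batch_size 1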
